-
Notifications
You must be signed in to change notification settings - Fork 15.2k
[mlir][Vector] Add vector.shuffle fold for poison inputs
#125608
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
|
@llvm/pr-subscribers-mlir-vector @llvm/pr-subscribers-mlir Author: Diego Caballero (dcaballe) Changes: #124863 added folding support for poison indices to `vector.shuffle`. This PR adds support for folding `vector.shuffle` ops with one or two poison input vectors. Full diff: https://github.com/llvm/llvm-project/pull/125608.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 93f89eda2da5a6..8d5691f38f273c 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -26,7 +26,6 @@
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
-#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/IRMapping.h"
@@ -42,7 +41,6 @@
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/ADT/TypeSwitch.h"
-#include "llvm/ADT/bit.h"
#include <cassert>
#include <cstdint>
@@ -2696,25 +2694,45 @@ OpFoldResult vector::ShuffleOp::fold(FoldAdaptor adaptor) {
if (!v1Attr || !v2Attr)
return {};
+ // Fold shuffle poison, poison -> poison.
+ bool isV1Poison = isa<ub::PoisonAttr>(v1Attr);
+ bool isV2Poison = isa<ub::PoisonAttr>(v2Attr);
+ if (isV1Poison && isV2Poison)
+ return ub::PoisonAttr::get(getContext());
+
// Only support 1-D for now to avoid complicated n-D DenseElementsAttr
// manipulation.
if (v1Type.getRank() != 1)
return {};
- int64_t v1Size = v1Type.getDimSize(0);
+ // Poison input attributes need special handling as they are not
+ // DenseElementsAttr. If an index is poison, we select the first element of
+ // the first non-poison input.
+ SmallVector<Attribute> v1Elements, v2Elements;
+ Attribute poisonElement;
+ if (!isV2Poison) {
+ v2Elements =
+ to_vector(cast<DenseElementsAttr>(v2Attr).getValues<Attribute>());
+ poisonElement = v2Elements[0];
+ }
+ if (!isV1Poison) {
+ v1Elements =
+ to_vector(cast<DenseElementsAttr>(v1Attr).getValues<Attribute>());
+ poisonElement = v1Elements[0];
+ }
SmallVector<Attribute> results;
- auto v1Elements = cast<DenseElementsAttr>(v1Attr).getValues<Attribute>();
- auto v2Elements = cast<DenseElementsAttr>(v2Attr).getValues<Attribute>();
+ int64_t v1Size = v1Type.getDimSize(0);
for (int64_t maskIdx : mask) {
Attribute indexedElm;
- // Select v1[0] for poison indices.
// TODO: Return a partial poison vector when supported by the UB dialect.
if (maskIdx == ShuffleOp::kPoisonIndex) {
- indexedElm = v1Elements[0];
+ indexedElm = poisonElement;
} else {
- indexedElm =
- maskIdx < v1Size ? v1Elements[maskIdx] : v2Elements[maskIdx - v1Size];
+ if (maskIdx < v1Size)
+ indexedElm = isV1Poison ? poisonElement : v1Elements[maskIdx];
+ else
+ indexedElm = isV2Poison ? poisonElement : v2Elements[maskIdx - v1Size];
}
results.push_back(indexedElm);
@@ -3332,13 +3350,15 @@ class InsertStridedSliceConstantFolder final
!destVector.hasOneUse())
return failure();
- auto denseDest = llvm::cast<DenseElementsAttr>(vectorDestCst);
-
TypedValue<VectorType> sourceValue = op.getSource();
Attribute sourceCst;
if (!matchPattern(sourceValue, m_Constant(&sourceCst)))
return failure();
+ // TODO: Support poison.
+ if (isa<ub::PoisonAttr>(vectorDestCst) || isa<ub::PoisonAttr>(sourceCst))
+ return failure();
+
// TODO: Handle non-unit strides when they become available.
if (op.hasNonUnitStrides())
return failure();
@@ -3355,6 +3375,7 @@ class InsertStridedSliceConstantFolder final
// increasing linearized position indices.
// Because the destination may have higher dimensionality than the slice,
// we keep track of two overlapping sets of positions and offsets.
+ auto denseDest = llvm::cast<DenseElementsAttr>(vectorDestCst);
auto denseSlice = llvm::cast<DenseElementsAttr>(sourceCst);
auto sliceValuesIt = denseSlice.value_begin<Attribute>();
auto newValues = llvm::to_vector(denseDest.getValues<Attribute>());
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 6858f0d56e6412..65c3ab264283d2 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -2023,6 +2023,45 @@ func.func @shuffle_1d_poison_idx() -> vector<4xi32> {
// -----
+// CHECK-LABEL: func @shuffle_1d_rhs_lhs_poison
+// CHECK-NOT: vector.shuffle
+// CHECK: %[[V:.+]] = ub.poison : vector<4xi32>
+// CHECK: return %[[V]]
+func.func @shuffle_1d_rhs_lhs_poison() -> vector<4xi32> {
+ %v0 = ub.poison : vector<3xi32>
+ %v1 = ub.poison : vector<3xi32>
+ %shuffle = vector.shuffle %v0, %v1 [3, 1, 5, 4] : vector<3xi32>, vector<3xi32>
+ return %shuffle : vector<4xi32>
+}
+
+// -----
+
+// CHECK-LABEL: func @shuffle_1d_lhs_poison
+// CHECK-NOT: vector.shuffle
+// CHECK: %[[V:.+]] = arith.constant dense<[5, 4, 5, 5]> : vector<4xi32>
+// CHECK: return %[[V]]
+func.func @shuffle_1d_lhs_poison() -> vector<4xi32> {
+ %v0 = arith.constant dense<[5, 4, 3]> : vector<3xi32>
+ %v1 = ub.poison : vector<3xi32>
+ %shuffle = vector.shuffle %v0, %v1 [3, 1, 5, 4] : vector<3xi32>, vector<3xi32>
+ return %shuffle : vector<4xi32>
+}
+
+// -----
+
+// CHECK-LABEL: func @shuffle_1d_rhs_poison
+// CHECK-NOT: vector.shuffle
+// CHECK: %[[V:.+]] = arith.constant dense<[2, 2, 0, 1]> : vector<4xi32>
+// CHECK: return %[[V]]
+func.func @shuffle_1d_rhs_poison() -> vector<4xi32> {
+ %v0 = ub.poison : vector<3xi32>
+ %v1 = arith.constant dense<[2, 1, 0]> : vector<3xi32>
+ %shuffle = vector.shuffle %v0, %v1 [3, 1, 5, 4] : vector<3xi32>, vector<3xi32>
+ return %shuffle : vector<4xi32>
+}
+
+// -----
+
// CHECK-LABEL: func @shuffle_canonicalize_0d
func.func @shuffle_canonicalize_0d(%v0 : vector<i32>, %v1 : vector<i32>) -> vector<1xi32> {
// CHECK: vector.broadcast %{{.*}} : vector<i32> to vector<1xi32>
|
We recently added folding support for poison indices to `vector.shuffle`. This PR adds support for folding poison inputs.
ece1c6d to
46a4887
Compare
banach-space
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, thanks!
Left a couple of optional nits.
| %v0 = arith.constant dense<[5, 4, 3]> : vector<3xi32> | ||
| %v1 = ub.poison : vector<3xi32> | ||
| %shuffle = vector.shuffle %v0, %v1 [3, 1, 5, 4] : vector<3xi32>, vector<3xi32> |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[nit] There's a value and index == 5, so it's not obvious that the first element of %v0 is in any way significant. Perhaps use some more distinct number? (e.g. 123).
| %v0 = arith.constant dense<[5, 4, 3]> : vector<3xi32> | |
| %v1 = ub.poison : vector<3xi32> | |
| %shuffle = vector.shuffle %v0, %v1 [3, 1, 5, 4] : vector<3xi32>, vector<3xi32> | |
| %v0 = arith.constant dense<[123, 4, 3]> : vector<3xi32> | |
| %v1 = ub.poison : vector<3xi32> | |
| %shuffle = vector.shuffle %v0, %v1 [3, 1, 123, 4] : vector<3xi32>, vector<3xi32> |
I appreciate that this is obvious right now, but let's also cater for our future selves :)
| // Poison input attributes need special handling as they are not | ||
| // DenseElementsAttr. If an index is poison, we select the first element of | ||
| // the first non-poison input. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[nit] To me this is a fairly significant (and not immediately intuitive) part of the design. Perhaps move above the signature?
Also, is this based on some prior-art? Just curious, this does make sense.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not sure I follow the prior-art part. Do you mean why we pick the first element of the first non-poison input? Poison is basically UB so given that we can't represent a partially poison vector we just make a random decision, which is ok as part of the UB behavior.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's valid to substitute poison with an arbitrary value
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's valid to substitute poison with an arbitrary value
Sure, but we are selecting a specific "arbitrary value" :)
Not sure I follow the prior-art part.
I was just curious whether there's any rationale behind this specific option. For example, does something else in LLVM or MLIR make a similar choice?
Basically, what I'm missing is "why would we select the first element"? Something along these lines would be helpful:
It doesn't matter what we select, but we need to make a choice. We choose the first element.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sure, but we are selecting a specific "arbitrary value" :)
??? In this context, arbitrary is synonymous with non-deterministic, as in: absolutely any value will do and the choice doesn't have to be fair by any definition of fair.
|
It looks like Github has been in "Processing updates" stage (see top of the page) for almost an hour... Weird... |
llvm#124863 added folding support for poison indices to `vector.shuffle`. This PR adds support for folding `vector.shuffle` ops with one or two poison input vectors.
#124863 added folding support for poison indices to
`vector.shuffle`. This PR adds support for folding `vector.shuffle` ops with one or two poison input vectors.